{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 7 - Distributional Q-learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "probs = np.array([0.6, 0.1, 0.1, 0.1, 0.1])\n",
    "outcomes = np.array([18, 21, 17, 17, 21])\n",
    "expected_value = 0.0\n",
    "for i in range(probs.shape[0]):\n",
    "    expected_value += probs[i] * outcomes[i]\n",
    "\n",
    "print(expected_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "expected_value = probs @ outcomes\n",
    "print(expected_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t0 = 18.4\n",
    "T = lambda: t0 + np.random.randn(1)\n",
    "T()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "vmin,vmax = -10.,10. #A\n",
    "nsup=51 #B\n",
    "support = np.linspace(vmin,vmax,nsup) #C\n",
    "probs = np.ones(nsup)\n",
    "probs /= probs.sum()\n",
    "z3 = torch.from_numpy(probs).float()\n",
    "plt.bar(support,probs) #D"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_dist(r,support,probs,lim=(-10.,10.),gamma=0.8):\n",
    "    nsup = probs.shape[0]\n",
    "    vmin,vmax = lim[0],lim[1]\n",
    "    dz = (vmax-vmin)/(nsup-1.) #A\n",
    "    bj = np.round((r-vmin)/dz) #B\n",
    "    bj = int(np.clip(bj,0,nsup-1)) #C\n",
    "    m = probs.clone()\n",
    "    j = 1\n",
    "    for i in range(bj,1,-1): #D\n",
    "        m[i] += np.power(gamma,j) * m[i-1]\n",
    "        j += 1\n",
    "    j = 1\n",
    "    for i in range(bj,nsup-1,1): #E\n",
    "        m[i] += np.power(gamma,j) * m[i+1]\n",
    "        j += 1\n",
    "    m /= m.sum() #F\n",
    "    return m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "probs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ob_reward = -1\n",
    "Z = torch.from_numpy(probs).float()\n",
    "Z = update_dist(ob_reward,torch.from_numpy(support).float(),Z,lim=(vmin,vmax),gamma=0.1)\n",
    "plt.bar(support,Z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ob_rewards = [10,10,10,0,1,0,-10,-10,10,10]\n",
    "for i in range(len(ob_rewards)):\n",
    "    Z = update_dist(ob_rewards[i], torch.from_numpy(support).float(), Z, lim=(vmin,vmax), gamma=0.5)\n",
    "plt.bar(support, Z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ob_rewards = [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]\n",
    "for i in range(len(ob_rewards)):\n",
    "    Z = update_dist(ob_rewards[i], torch.from_numpy(support).float(), \\\n",
    "    Z, lim=(vmin,vmax), gamma=0.7)\n",
    "plt.bar(support, Z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dist_dqn(x,theta,aspace=3): #A\n",
    "    dim0,dim1,dim2,dim3 = 128,100,25,51 #B\n",
    "    t1 = dim0*dim1\n",
    "    t2 = dim2*dim1\n",
    "    theta1 = theta[0:t1].reshape(dim0,dim1) #C\n",
    "    theta2 = theta[t1:t1 + t2].reshape(dim1,dim2)\n",
    "    l1 = x @ theta1 #D\n",
    "    l1 = torch.selu(l1)\n",
    "    l2 = l1 @ theta2 #E\n",
    "    l2 = torch.selu(l2)\n",
    "    l3 = []\n",
    "    for i in range(aspace): #F\n",
    "        step = dim2*dim3\n",
    "        theta5_dim = t1 + t2 + i * step\n",
    "        theta5 = theta[theta5_dim:theta5_dim+step].reshape(dim2,dim3)\n",
    "        l3_ = l2 @ theta5 #G\n",
    "        l3.append(l3_)\n",
    "    l3 = torch.stack(l3,dim=1) #H\n",
    "    l3 = torch.nn.functional.softmax(l3,dim=2)\n",
    "    return l3.squeeze()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_target_dist(dist_batch,action_batch,reward_batch,support,lim=(-10,10),gamma=0.8):\n",
    "    nsup = support.shape[0]\n",
    "    vmin,vmax = lim[0],lim[1]\n",
    "    dz = (vmax-vmin)/(nsup-1.)\n",
    "    target_dist_batch = dist_batch.clone()\n",
    "    for i in range(dist_batch.shape[0]): #A\n",
    "        dist_full = dist_batch[i]\n",
    "        action = int(action_batch[i].item())\n",
    "        dist = dist_full[action]\n",
    "        r = reward_batch[i]\n",
    "        if r != -1: #B\n",
    "            target_dist = torch.zeros(nsup)\n",
    "            bj = np.round((r-vmin)/dz)\n",
    "            bj = int(np.clip(bj,0,nsup-1))\n",
    "            target_dist[bj] = 1.\n",
    "        else: #C\n",
    "            target_dist = update_dist(r,support,dist,lim=lim,gamma=gamma)\n",
    "        target_dist_batch[i,action,:] = target_dist #D\n",
    "        \n",
    "    return target_dist_batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lossfn(x,y):#A\n",
    "    loss = torch.Tensor([0.])\n",
    "    loss.requires_grad=True\n",
    "    for i in range(x.shape[0]): #B \n",
    "        loss_ = -1 *  torch.log(x[i].flatten(start_dim=0)) @ y[i].flatten(start_dim=0) #C\n",
    "        loss = loss + loss_\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "aspace = 3 #A\n",
    "tot_params = 128*100 + 25*100 + aspace*25*51 #B\n",
    "theta = torch.randn(tot_params)/10. #C\n",
    "theta.requires_grad=True\n",
    "theta_2 = theta.detach().clone() #D\n",
    "#\n",
    "vmin,vmax= -10,10\n",
    "gamma=0.9\n",
    "lr = 0.00001\n",
    "update_rate = 75 #E\n",
    "support = torch.linspace(-10,10,51)\n",
    "state = torch.randn(2,128)/10. #F\n",
    "action_batch = torch.Tensor([0,2]) #G\n",
    "reward_batch = torch.Tensor([0,10]) #H\n",
    "losses = [] \n",
    "pred_batch = dist_dqn(state,theta,aspace=aspace) #I\n",
    "target_dist = get_target_dist(pred_batch,action_batch,reward_batch, \\\n",
    "                                 support, lim=(vmin,vmax),gamma=gamma) #J\n",
    "\n",
    "plt.plot((target_dist.flatten(start_dim=1)[0].data.numpy()),color='red',label='target')\n",
    "plt.plot((pred_batch.flatten(start_dim=1)[0].data.numpy()),color='green',label='pred')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(1000):\n",
    "    reward_batch = torch.Tensor([0,8]) + torch.randn(2)/10.0 #A\n",
    "    pred_batch = dist_dqn(state,theta,aspace=aspace) #B\n",
    "    pred_batch2 = dist_dqn(state,theta_2,aspace=aspace) #C\n",
    "    target_dist = get_target_dist(pred_batch2,action_batch,reward_batch, \\\n",
    "                                 support, lim=(vmin,vmax),gamma=gamma) #D\n",
    "    loss = lossfn(pred_batch,target_dist.detach()) #E\n",
    "    losses.append(loss.item())\n",
    "    loss.backward()\n",
    "    # Gradient Descent\n",
    "    with torch.no_grad():\n",
    "        theta -= lr * theta.grad\n",
    "    theta.requires_grad = True\n",
    "    \n",
    "    if i % update_rate == 0: #F \n",
    "        theta_2 = theta.detach().clone()\n",
    "\n",
    "plt.plot((target_dist.flatten(start_dim=1)[0].data.numpy()),color='red',label='target')\n",
    "plt.plot((pred_batch.flatten(start_dim=1)[0].data.numpy()),color='green',label='pred')\n",
    "plt.plot(losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.11"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tpred = pred_batch\n",
    "cs = ['gray','green','red']\n",
    "num_batch = 2\n",
    "labels = ['Action {}'.format(i,) for i in range(aspace)]\n",
    "fig,ax = plt.subplots(nrows=num_batch,ncols=aspace)\n",
    "\n",
    "for j in range(num_batch): #A \n",
    "    for i in range(tpred.shape[1]): #B\n",
    "        ax[j,i].bar(support.data.numpy(),tpred[j,i,:].data.numpy(),\\\n",
    "                label='Action {}'.format(i),alpha=0.9,color=cs[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preproc_state(state):\n",
    "    p_state = torch.from_numpy(state).unsqueeze(dim=0).float()\n",
    "    p_state = torch.nn.functional.normalize(p_state,dim=1) #A\n",
    "    return p_state\n",
    "\n",
    "def get_action(dist,support):\n",
    "    actions = []\n",
    "    for b in range(dist.shape[0]): #B\n",
    "        expectations = [support @ dist[b,a,:] for a in range(dist.shape[1])] #C\n",
    "        action = int(np.argmax(expectations)) #D\n",
    "        actions.append(action)\n",
    "    actions = torch.Tensor(actions).int()\n",
    "    return actions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.13"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "from collections import deque\n",
    "env = gym.make('Freeway-ram-v0')\n",
    "aspace = 3\n",
    "env.env.get_action_meanings()\n",
    "\n",
    "vmin,vmax = -10,10\n",
    "replay_size = 200\n",
    "batch_size = 50\n",
    "nsup = 51\n",
    "dz = (vmax - vmin) / (nsup-1)\n",
    "support = torch.linspace(vmin,vmax,nsup)\n",
    "\n",
    "replay = deque(maxlen=replay_size) #A\n",
    "lr = 0.0001 #B \n",
    "gamma = 0.1 #C \n",
    "epochs = 1300\n",
    "eps = 0.20 #D starting epsilon for epsilon-greedy policy\n",
    "eps_min = 0.05 #E ending epsilon\n",
    "priority_level = 5 #F \n",
    "update_freq = 25 #G \n",
    "\n",
    "#Initialize DQN parameter vector\n",
    "tot_params = 128*100 + 25*100 + aspace*25*51  #H \n",
    "theta = torch.randn(tot_params)/10. #I \n",
    "theta.requires_grad=True\n",
    "theta_2 = theta.detach().clone() #J \n",
    "\n",
    "losses = []\n",
    "cum_rewards = [] #K \n",
    "renders = []\n",
    "state = preproc_state(env.reset())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 7.14"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from random import shuffle\n",
    "for i in range(epochs):\n",
    "    pred = dist_dqn(state,theta,aspace=aspace)\n",
    "    if i < replay_size or np.random.rand(1) < eps: #A\n",
    "        action = np.random.randint(aspace)\n",
    "    else:\n",
    "        action = get_action(pred.unsqueeze(dim=0).detach(),support).item()\n",
    "    state2, reward, done, info = env.step(action) #B\n",
    "    state2 = preproc_state(state2)\n",
    "    if reward == 1: cum_rewards.append(1) \n",
    "    reward = 10 if reward == 1 else reward #C\n",
    "    reward = -10 if done else reward #D\n",
    "    reward = -1 if reward == 0 else reward #E\n",
    "    exp = (state,action,reward,state2) #F\n",
    "    replay.append(exp) #G\n",
    "    \n",
    "    if reward == 10: #H\n",
    "        for e in range(priority_level):\n",
    "            replay.append(exp)\n",
    "            \n",
    "    shuffle(replay)\n",
    "    state = state2\n",
    "\n",
    "    if len(replay) == replay_size: #I\n",
    "        indx = np.random.randint(low=0,high=len(replay),size=batch_size)\n",
    "        exps = [replay[j] for j in indx]\n",
    "        state_batch = torch.stack([ex[0] for ex in exps],dim=1).squeeze()\n",
    "        action_batch = torch.Tensor([ex[1] for ex in exps])\n",
    "        reward_batch = torch.Tensor([ex[2] for ex in exps])\n",
    "        state2_batch = torch.stack([ex[3] for ex in exps],dim=1).squeeze()\n",
    "        pred_batch = dist_dqn(state_batch.detach(),theta,aspace=aspace)\n",
    "        pred2_batch = dist_dqn(state2_batch.detach(),theta_2,aspace=aspace)\n",
    "        target_dist = get_target_dist(pred2_batch,action_batch,reward_batch, \\\n",
    "                                     support, lim=(vmin,vmax),gamma=gamma)\n",
    "        loss = lossfn(pred_batch,target_dist.detach())\n",
    "        losses.append(loss.item())\n",
    "        loss.backward()\n",
    "        with torch.no_grad(): #J\n",
    "            theta -= lr * theta.grad\n",
    "        theta.requires_grad = True\n",
    "        \n",
    "    if i % update_freq == 0: #K\n",
    "        theta_2 = theta.detach().clone()\n",
    "        \n",
    "    if i > 100 and eps > eps_min: #L\n",
    "        dec = 1./np.log2(i)\n",
    "        dec /= 1e3\n",
    "        eps -= dec\n",
    "    \n",
    "    if done: #M\n",
    "        state = preproc_state(env.reset())\n",
    "        done = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(losses)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:deeprl]",
   "language": "python",
   "name": "conda-env-deeprl-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
